package com.gjxx.common.utils;

import com.gjxx.common.annotation.ClassScanner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import org.springframework.core.io.support.ResourcePatternResolver;
import org.springframework.core.type.classreading.CachingMetadataReaderFactory;
import org.springframework.core.type.classreading.MetadataReader;
import org.springframework.core.type.classreading.MetadataReaderFactory;
import org.springframework.core.type.filter.AnnotationTypeFilter;
import org.springframework.core.type.filter.TypeFilter;

import java.io.IOException;
import java.lang.annotation.Annotation;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

/**
 * Classpath scanning utility: collects the classes found under a set of base
 * packages that match at least one of the configured annotation filters.
 *
 * @author zhaxc
 */
public class ClasspathPackageScanner {

	protected final Logger logger = LoggerFactory.getLogger(getClass());

	// Ant-style suffix appended to a package path to select every class file beneath it.
	private static final String RESOURCE_PATTERN = "/**/*.class";

	private final ResourcePatternResolver resourcePatternResolver = new PathMatchingResourcePatternResolver();

	private final Set<String> packageSet = new HashSet<>();

	private final Set<TypeFilter> typeFilterSet = new HashSet<>();

	/**
	 * Creates a scanner for the given packages and annotations.
	 * {@code null} arrays and {@code null} elements are ignored.
	 *
	 * @param packagesToScan base packages to scan (sub-packages included)
	 * @param annotationFilters annotations to look for; a class matches when it
	 *        matches the filter of any one of them (meta-annotations are not considered)
	 */
	@SafeVarargs
	public ClasspathPackageScanner(String[] packagesToScan, Class<? extends ClassScanner>... annotationFilters) {

		if (packagesToScan != null) {
			for (String pkg : packagesToScan) {
				// A null package name cannot be converted to a resource path; skip it.
				if (pkg != null) {
					packageSet.add(pkg);
				}
			}
		}

		if (annotationFilters != null) {
			for (Class<? extends Annotation> annotation : annotationFilters) {
				if (annotation != null) {
					typeFilterSet.add(new AnnotationTypeFilter(annotation, false));
				}
			}
		}
	}

	/**
	 * Scans the configured packages and returns the classes that match at least
	 * one of the configured annotation filters.
	 *
	 * <p>Every call performs a new scan and returns a new set, so results
	 * obtained from earlier calls are never modified.
	 *
	 * @return the matching classes; empty when no package or no annotation was configured
	 * @throws IOException if the class file resources cannot be resolved or read
	 * @throws ClassNotFoundException if a matching class cannot be loaded
	 */
	public Set<Class<?>> getClasses() throws IOException, ClassNotFoundException {

		Set<Class<?>> matches = new HashSet<>();

		// Without packages there is nothing to scan; without filters nothing can match.
		if (this.packageSet.isEmpty() || this.typeFilterSet.isEmpty()) {
			return matches;
		}

		// One factory for the whole scan so its metadata cache is shared across packages.
		MetadataReaderFactory readerFactory = new CachingMetadataReaderFactory(this.resourcePatternResolver);
		ClassLoader classLoader = resolveClassLoader();

		for (String pkg : this.packageSet) {

			String pattern = ResourcePatternResolver.CLASSPATH_ALL_URL_PREFIX
					+ org.springframework.util.ClassUtils.convertClassNameToResourcePath(pkg)
					+ RESOURCE_PATTERN;

			for (Resource resource : this.resourcePatternResolver.getResources(pattern)) {
				if (!resource.isReadable()) {
					continue;
				}
				MetadataReader reader = readerFactory.getMetadataReader(resource);
				if (matchesEntityTypeFilter(reader, readerFactory)) {
					String className = reader.getClassMetadata().getClassName();
					// Load through the loader that located the resource; initialize=true
					// keeps the class-initialization behavior of Class.forName(String).
					matches.add(Class.forName(className, true, classLoader));
				}
			}
		}

		if (logger.isInfoEnabled()) {
			for (Class<?> clazz : matches) {
				logger.info("Found class:{}", clazz.getName());
			}
		}
		return matches;
	}

	/**
	 * Picks the class loader used to load matching classes: the resolver's
	 * loader when it has one, otherwise the loader that defined this scanner.
	 *
	 * @return the class loader to load scanned classes with
	 */
	private ClassLoader resolveClassLoader() {
		ClassLoader loader = this.resourcePatternResolver.getClassLoader();
		return loader != null ? loader : ClasspathPackageScanner.class.getClassLoader();
	}

	/**
	 * Checks whether the scanned class matches any of the configured annotation filters.
	 *
	 * @param reader metadata reader for the class being examined
	 * @param readerFactory factory the filters may use to read metadata of related classes
	 * @return {@code true} if at least one filter matches, {@code false} otherwise
	 * @throws IOException if class metadata cannot be read while matching
	 */
	private boolean matchesEntityTypeFilter(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {

		for (TypeFilter filter : this.typeFilterSet) {
			if (filter.match(reader, readerFactory)) {
				return true;
			}
		}
		return false;
	}

}
